package com.fishery.monitor.controller;

import com.fishery.entity.Result;
import com.fishery.entity.StatusCode;
import com.fishery.monitor.arima.ARIMA;
import com.fishery.monitor.pojo.MeteorologicalData;
import com.fishery.monitor.pojo.Warning;
import com.fishery.monitor.pojo.WaterData;
import com.fishery.monitor.service.MeteorologicalDataService;
import com.fishery.monitor.service.WarningService;
import com.fishery.monitor.service.WaterDataService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import net.sf.json.JSONObject;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.client.RestTemplate;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * REST controller exposing forecasting endpoints for meteorological and water-quality data.
 *
 * <p>Expected front-end workflow (also published in the Swagger description below):
 * <ol>
 *   <li>List every device of the base and put the device names in the first drop-down.</li>
 *   <li>The second drop-down holds the selected device's check items. Meteorological and
 *       water-quality devices have different items. The front end shows Chinese names while the
 *       back end expects the field identifier, so the mapping must be kept consistent (e.g. the
 *       front end shows "温度" and sends "airTemperature").</li>
 *   <li>The third drop-down holds the algorithm: svm, LR, DT, RF for water quality; arima,
 *       lstmrnn for meteorological devices.</li>
 *   <li>The fourth and fifth drop-downs hold the start and end time, formatted exactly as
 *       {@code 2019-10-08 04:47:47}.</li>
 * </ol>
 *
 * @author Administrator
 */
@Api(tags = "预测接口",
		description = " 1、查出该基地的所有设备，将设备名字放在第一个下拉框。" +
				"2、第二个下拉框放设备的检测项，气象设备与水质设备不同，因为前端展示的是中文检测名，而数据库是字母，之间的装换要统一下，例如前端显示“温度”，则传给后端则是“airTemperature” " +
				"3、第三个下拉框放算法模型，水质算法有svm,LR,DT,RF四种；气象设备有两种arima,lstmrnn " +
				"4、第四个第五个下拉框放起始时间和结束时间，时间格式只能是  2019-10-08 04:47:47")
@RestController
@CrossOrigin
@RequestMapping("/datarecord/forecast")
public class ForecastController {

	/** Base URL of the external water-quality prediction service; the algorithm name is appended. */
	private static final String WATER_FORECAST_BASE_URL = "http://120.78.14.141:9001/";

	/** Maximum number of trailing feature rows sent to the water-quality service for prediction. */
	private static final int WATER_PREDICTION_WINDOW = 10;

	@Autowired
	private MeteorologicalDataService meteorologicalDataService;

	@Autowired
	private WaterDataService waterDataService;

	@Autowired
	private WarningService warningService;

	/**
	 * Back-tests an ARIMA model on one check item of a meteorological device.
	 *
	 * <p>The page of records is read from the last index to the first. Records whose value is
	 * {@code null} are skipped. The last fifth of that series (integer division) is then predicted
	 * one point at a time. Each point is predicted from all real observations that precede it, so
	 * the window grows by exactly one per step. Every prediction is also appended to the series
	 * returned as {@code orgindata}.
	 *
	 * @param condiction query conditions; {@code checkItemName} selects the field and
	 *                   {@code equipmentId} selects the warning thresholds
	 * @param page       page number forwarded to the data service
	 * @param size       page size forwarded to the data service
	 * @return a result whose data holds {@code orgindata} (series followed by the predictions),
	 *         {@code predictdata} (the predictions) and {@code warn} (true when any prediction lies
	 *         outside the configured thresholds)
	 */
	@ApiOperation("分页,基于ARIMA的方法，预测对应的数据")
	@RequestMapping(value = "/arima/{page}/{size}", method = RequestMethod.POST)
	@ResponseBody
	public Result getPredictedDataByARIMA(@RequestBody Map<String, String> condiction, @PathVariable int page, @PathVariable int size) {
		// TODO: only the requested check item is needed; the query could avoid loading every field.
		List<MeteorologicalData> dataRecordList = meteorologicalDataService.querySource(condiction, page, size).getRecords();
		String checkItemName = condiction.get("checkItemName");

		// Build the series in reverse page order. Missing readings are skipped (as the water
		// endpoint does) because a null would otherwise fail when unboxed into double[] below.
		List<Double> arraylist = new ArrayList<>();
		for (int i = dataRecordList.size() - 1; i >= 0; i--) {
			Double value = dataRecordList.get(i).getCheckItem(checkItemName);
			if (value != null) {
				arraylist.add(value);
			}
		}

		List<Double> predictValueList = new ArrayList<>();
		// Freeze the series length: predictions are appended to arraylist inside the loop and must
		// not widen the window (previously the window grew by two per step and mixed predictions in).
		int observedCount = arraylist.size();
		int num = observedCount / 5;
		for (int i = 0; i < num; i++) {
			// Window = every real observation before the point being predicted.
			double[] dataArray = toPrimitiveArray(arraylist, observedCount - num + i);

			ARIMA arima = new ARIMA(dataArray);
			int[] model = arima.getARIMAmodel();
			// Compute once so the reported prediction and the appended value are the same number.
			double predicted = arima.aftDeal(arima.predictValue(model[0], model[1]));
			predictValueList.add(predicted);
			arraylist.add(predicted);
		}

		boolean warn = isOutOfThreshold(findWarning(condiction.get("equipmentId"), checkItemName), predictValueList);

		Map<String, Object> map = new HashMap<>();
		map.put("orgindata", arraylist);
		map.put("predictdata", predictValueList);
		map.put("warn", warn);
		return new Result(true, StatusCode.OK, "预测成功", map);
	}

	/**
	 * Fits a one-layer LSTM (5 units) to one check item of a meteorological device and extrapolates.
	 *
	 * <p>Training pairs are (record i, record i - 1), walking from the last record down to index 100.
	 * Test inputs start at index {@code size - 250}. The fitted test output is returned as
	 * {@code predictdata}. Then 50 further values are produced by feeding each output back as the
	 * next input; the first of these is the last fitted test value. At least 101 records are
	 * required; with fewer, an empty result is returned instead of failing.
	 *
	 * @param condiction query conditions; {@code checkItemName} selects the field and
	 *                   {@code equipmentId} selects the warning thresholds
	 * @param page       page number forwarded to the data service
	 * @param size       page size forwarded to the data service
	 * @return a result whose data holds {@code orgindata} (training inputs), {@code predictdata}
	 *         (fitted test output), {@code realPredictData} (50 extrapolated values) and {@code warn}
	 */
	@ApiOperation("分页,基于LSTM循环神经网络预测")
	@RequestMapping(value = "/lstmrnn/{page}/{size}", method = RequestMethod.POST)
	@ResponseBody
	public Result getPredictedDataByLSTMRNN(@RequestBody Map<String, String> condiction, @PathVariable int page, @PathVariable int size) {
		List<MeteorologicalData> dataRecordList = meteorologicalDataService.querySource(condiction, page, size).getRecords();
		String checkItemName = condiction.get("checkItemName");

		// Records at indices 0..99 are excluded from training, so 100 or fewer records previously
		// crashed with NegativeArraySizeException. Answer like the ARIMA endpoint does when it has
		// nothing to predict: empty series and no warning.
		if (dataRecordList.size() <= 100) {
			Map<String, Object> empty = new HashMap<>();
			empty.put("orgindata", new double[0]);
			empty.put("predictdata", new ArrayList<Double>());
			empty.put("realPredictData", new double[0]);
			empty.put("warn", false);
			return new Result(true, StatusCode.OK, "预测成功", empty);
		}

		List<Double> predictValueList = new ArrayList<>();
		double[] inputTrainDataArr = new double[dataRecordList.size() - 100];
		double[] outputTrainDataArr = new double[dataRecordList.size() - 100];
		double[] inputTestDataArr = new double[99];
		double[] outputTestDataArr = new double[99];

		// Training set: input = record i, label = record i - 1, from the last record down to index 100.
		for (int i = dataRecordList.size() - 1, j = 0; i >= 100 && j < inputTrainDataArr.length; i--, j++) {
			inputTrainDataArr[j] = dataRecordList.get(i).getCheckItem(checkItemName);
			outputTrainDataArr[j] = dataRecordList.get(i - 1).getCheckItem(checkItemName);
		}
		DataSet trainData = toSequenceDataSet(inputTrainDataArr, outputTrainDataArr);

		// Test set: same (i, i - 1) pairing, starting at index size - 250. The labels were previously
		// offset by two; they are not read after normalisation.
		// NOTE(review): with fewer than ~251 records this loop does not run and the test inputs stay zero.
		for (int j = 0, i = dataRecordList.size() - 250; i > 0 && j < outputTestDataArr.length; i--, j++) {
			inputTestDataArr[j] = dataRecordList.get(i).getCheckItem(checkItemName);
			outputTestDataArr[j] = dataRecordList.get(i - 1).getCheckItem(checkItemName);
		}
		DataSet testData = toSequenceDataSet(inputTestDataArr, outputTestDataArr);

		// Scale features and labels to [0, 1] using statistics collected from the training data only.
		NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
		normalizer.fitLabel(true);
		normalizer.fit(trainData);
		normalizer.transform(trainData);
		normalizer.transform(testData);

		// One LSTM layer (1 -> 5, tanh) followed by a linear RNN output layer trained with MSE.
		MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
				.seed(140)
				.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
				.weightInit(WeightInit.XAVIER)
				.updater(new Nesterovs(0.0015, 0.9))
				.list()
				.layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(1).nOut(5)
						.build())
				.layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
						.activation(Activation.IDENTITY).nIn(5).nOut(1).build())
				.build();

		MultiLayerNetwork net = new MultiLayerNetwork(conf);
		net.init();

		int nEpochs = 200;
		System.out.println("开始训练...");
		for (int i = 0; i < nEpochs; i++) {
			net.fit(trainData);
		}
		System.out.println("训练结束");

		// Prime the recurrent state with the training sequence, then fit the test sequence.
		// This is a fit of existing data, not a forecast.
		net.rnnTimeStep(trainData.getFeatures());
		INDArray predicted = net.rnnTimeStep(testData.getFeatures());

		// Map the fitted values back to the original scale.
		normalizer.revertLabels(predicted);
		for (int i = 0; i < predicted.length(); i++) {
			predictValueList.add(predicted.getDouble(0, 0, i));
		}

		// Extrapolation: feed each output back in as the next single-step input.
		double[] realPredictData = new double[50];
		realPredictData[0] = predicted.getDouble(0, 0, predicted.length() - 1);
		for (int i = 1; i < realPredictData.length; i++) {
			INDArray indArray = Nd4j.create(new double[]{realPredictData[i - 1]}, new int[]{1, 1});
			normalizer.transform(indArray);
			INDArray realPredicted = net.rnnTimeStep(indArray);
			normalizer.revertLabels(realPredicted);
			realPredictData[i] = realPredicted.getDouble(0);
		}

		// NOTE(review): the threshold check uses the fitted values, not realPredictData; confirm intent.
		boolean warn = isOutOfThreshold(findWarning(condiction.get("equipmentId"), checkItemName), predictValueList);

		Map<String, Object> map = new HashMap<>();
		map.put("orgindata", inputTrainDataArr);
		map.put("predictdata", predictValueList);
		map.put("realPredictData", realPredictData);
		map.put("warn", warn);
		return new Result(true, StatusCode.OK, "预测成功", map);
	}

	/**
	 * Water-quality prediction delegated to an external HTTP service.
	 *
	 * <p>Consecutive non-null readings form (feature = record i, target = record i - 1) pairs,
	 * walking the page from the last record down. The last (up to) ten features, newest first, are
	 * sent as the prediction input. The request body carries these lists in their {@code toString()}
	 * form under {@code feature}, {@code target} and {@code prediction}.
	 *
	 * @param condiction query conditions; {@code checkItemName} selects the field and
	 *                   {@code equipmentId} selects the warning thresholds
	 * @param algorithm  algorithm name appended to the service URL (svm, LR, DT, RF per the API docs)
	 * @param page       page number forwarded to the data service
	 * @param size       page size forwarded to the data service
	 * @return a result whose data holds {@code orgindata} (targets), {@code predictdata} (the
	 *         service's {@code prediction} value, {@code null} if absent) and {@code warn}
	 */
	@ApiOperation("分页,水质预测，调用第三方接口")
	@RequestMapping(value = "/water/{algorithm}/{page}/{size}", method = RequestMethod.POST)
	@ResponseBody
	public Result getPredictedDataBySVN(@RequestBody Map<String, String> condiction, @PathVariable String algorithm, @PathVariable int page, @PathVariable int size) {
		List<WaterData> dataRecordList = waterDataService.querySource(condiction, page, size).getRecords();
		String checkItemName = condiction.get("checkItemName");

		// Training pairs; any pair with a missing reading is skipped.
		List<List<Double>> arraylist = new ArrayList<>();
		List<Double> predictValueList = new ArrayList<>();
		for (int i = dataRecordList.size() - 1; i >= 1; i--) {
			Double current = dataRecordList.get(i).getCheckItem(checkItemName);
			Double next = dataRecordList.get(i - 1).getCheckItem(checkItemName);
			if (current == null || next == null) {
				continue;
			}
			List<Double> feature = new ArrayList<>();
			feature.add(current);
			arraylist.add(feature);
			predictValueList.add(next);
		}

		// Prediction input: the last WATER_PREDICTION_WINDOW features, newest first. The lower bound
		// is clamped at 0; previously fewer than 10 rows caused IndexOutOfBoundsException.
		// The number of entries here is the number of predictions the service returns.
		List<List<Double>> realPredictValueList = new ArrayList<>();
		int windowStart = Math.max(0, arraylist.size() - WATER_PREDICTION_WINDOW);
		for (int i = arraylist.size() - 1; i >= windowStart; i--) {
			List<Double> feature = new ArrayList<>();
			feature.add(arraylist.get(i).get(0));
			realPredictValueList.add(feature);
		}

		// NOTE(review): algorithm comes straight from the request path; consider whitelisting
		// the supported names before building the URL.
		String url = WATER_FORECAST_BASE_URL + algorithm;
		JSONObject postData = new JSONObject();
		postData.put("feature", arraylist + "");
		postData.put("target", predictValueList + "");
		postData.put("prediction", realPredictValueList + "");

		RestTemplate restTemplate = new RestTemplate();
		JSONObject s = restTemplate.postForObject(url, postData, JSONObject.class);
		// postForObject may return null for an empty body; treat that like a missing prediction.
		Object prediction = s == null ? null : s.get("prediction");

		Map<String, Object> map = new HashMap<>();
		map.put("orgindata", predictValueList);
		map.put("predictdata", prediction);

		Warning warning = findWarning(condiction.get("equipmentId"), checkItemName);
		boolean warn = warning != null && prediction != null
				&& isOutOfThreshold(warning, parseNumberList(prediction.toString()));
		map.put("warn", warn);
		return new Result(true, StatusCode.OK, "预测成功", map);
	}

	/**
	 * Looks up the warning (threshold) configuration for a device channel.
	 *
	 * @param equipmentId   device id as received in the request
	 * @param checkItemName check-item identifier; converted with {@link #checkItemNameChanage}
	 * @return the first matching {@link Warning}, or {@code null} when none is configured
	 */
	@SuppressWarnings({"rawtypes", "unchecked"})
	private Warning findWarning(String equipmentId, String checkItemName) {
		// Raw map kept as in the original so it matches whatever parameter type findSearch declares.
		Map whereMap = new HashMap();
		whereMap.put("equipmentId", equipmentId);
		whereMap.put("channelName", checkItemNameChanage(checkItemName));
		List<Warning> search = warningService.findSearch(whereMap);
		return (search == null || search.isEmpty()) ? null : search.get(0);
	}

	/**
	 * Checks whether any value lies strictly outside {@code [minValue, maxValue]} of the warning.
	 *
	 * @return {@code false} when {@code warning} is {@code null}, otherwise whether any value is out of range
	 */
	private static boolean isOutOfThreshold(Warning warning, List<Double> values) {
		if (warning == null) {
			return false;
		}
		double maxV = warning.getMaxValue();
		double minV = warning.getMinValue();
		for (Double value : values) {
			if (value > maxV || value < minV) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Parses a textual number list such as {@code [1.0, 2.5]} into doubles.
	 * Square brackets are ignored. Blank entries are skipped, so {@code []} yields an empty list
	 * instead of a {@link NumberFormatException}.
	 */
	private static List<Double> parseNumberList(String text) {
		List<Double> numbers = new ArrayList<>();
		for (String token : text.replace("[", "").replace("]", "").split(",")) {
			String trimmed = token.trim();
			if (!trimmed.isEmpty()) {
				numbers.add(Double.parseDouble(trimmed));
			}
		}
		return numbers;
	}

	/** Copies the first {@code count} elements of {@code values} into a new primitive array. */
	private static double[] toPrimitiveArray(List<Double> values, int count) {
		double[] result = new double[count];
		for (int i = 0; i < count; i++) {
			result[i] = values.get(i);
		}
		return result;
	}

	/**
	 * Wraps one univariate sequence as a {@link DataSet} whose features and labels are each built
	 * from a {@code double[1][1][length]} array.
	 */
	private static DataSet toSequenceDataSet(double[] features, double[] labels) {
		INDArray featureArray = Nd4j.create(new double[][][]{{features}});
		INDArray labelArray = Nd4j.create(new double[][][]{{labels}});
		DataSet dataSet = new DataSet();
		dataSet.setFeatures(featureArray);
		dataSet.setLabels(labelArray);
		return dataSet;
	}

	/**
	 * Converts a check-item identifier to the Chinese channel name used by the warning configuration.
	 *
	 * @param checkItemName check-item identifier (may be {@code null})
	 * @return the channel name, or an empty string for {@code null} or unknown identifiers
	 */
	public static String checkItemNameChanage(String checkItemName) {
		if (checkItemName == null) {
			return "";
		}
		switch (checkItemName) {
			// Meteorological channels.
			// NOTE(review): the API docs cite "airTemperature", but this table uses "air_temperature"; confirm.
			case "electric_energy":
				return "电能";
			case "illumination":
				return "光照";
			case "wind_speed":
				return "风速";
			case "wind_direct":
				return "风向";
			case "air_temperature":
				return "气温";
			case "humidity":
				return "湿度";
			case "rain":
				return "雨量";
			case "soil_temperature":
				return "土温";
			// Water-quality channels.
			case "dissolvedOxygen":
				return "溶解氧";
			case "waterTemperature":
				return "水温";
			case "phValue":
				return "pH";
			case "ammoniaNitrogen":
				return "氨氮";
			case "conductivity":
				return "电导率";
			case "turbidity":
				return "浊度";
			case "permanganateIndex":
				// Spelling kept as-is: it presumably must match the stored channel name.
				return "高猛酸盐指数";
			case "phosphorus":
				return "总磷";
			case "nitrogen":
				return "总氮";
			case "chlorophyll":
				return "叶绿素α";
			case "algalDensity":
				return "藻密度";
			case "waterLevel":
				return "水位";
			default:
				return "";
		}
	}
}